"""
Phase 2: Japanification Index Construction
============================================
Construct three index variants:
1. Binary Japanification indicator
2. Continuous Japanification index (standardized, primary)
3. Rolling Japanification index (5-year moving average, minimum 3 observations)

Also: transition-based risk score, summary statistics.

Input:  japanification/data/processed/japan_panel.csv
Output: japanification/data/processed/japan_panel_indexed.csv
        japanification/output/tables/phase2_*.csv
"""

import pandas as pd
import numpy as np
from pathlib import Path

# Absolute project root; every other location below is derived from it.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
# Sub-project that holds the Japanification analysis.
JAPAN_DIR = PROJECT_DIR.joinpath("japanification")
# Panel CSVs are read from and written to this directory.
PROCESSED_DIR = JAPAN_DIR.joinpath("data", "processed")
# Destination for the phase2_*.csv summary tables.
TABLE_DIR = JAPAN_DIR.joinpath("output", "tables")


def _add_binary_indicators(df: pd.DataFrame) -> None:
    """Add the strict and robust 0/1 Japanification flags to *df* in place."""
    print("\n--- Binary Japanification Indicator ---")

    # Country-specific 25th percentile of real GDP growth.
    country_p25 = df.groupby('iso3')['rgdp_growth'].transform(
        lambda x: x.quantile(0.25)
    )

    # NOTE: comparisons against NaN evaluate to False, so country-years with
    # a missing input are coded 0 (not flagged) rather than left missing.

    # Strict: growth < country p25 AND inflation < 2% AND rate < 2%
    df['japan_binary_strict'] = (
        (df['rgdp_growth'] < country_p25) &
        (df['inflation_japan'] < 2.0) &
        (df['rate_japan'] < 2.0) &
        df['rate_japan'].notna()
    ).astype(int)

    # Robust (no rate requirement): growth < 2% AND inflation < 2%
    df['japan_binary_robust'] = (
        (df['rgdp_growth'] < 2.0) &
        (df['inflation_japan'] < 2.0)
    ).astype(int)

    print(f"  Strict binary: {df['japan_binary_strict'].sum()} events "
          f"({df['japan_binary_strict'].mean()*100:.1f}%)")
    print(f"  Robust binary: {df['japan_binary_robust'].sum()} events "
          f"({df['japan_binary_robust'].mean()*100:.1f}%)")


def _add_continuous_index(df: pd.DataFrame) -> None:
    """Add z-scores and the 2-/3-component and primary indices to *df* in place."""
    print("\n--- Continuous Japanification Index ---")

    # Standardize each component against the whole panel (not per country).
    g_mean, g_std = df['rgdp_growth'].mean(), df['rgdp_growth'].std()
    i_mean, i_std = df['inflation_japan'].mean(), df['inflation_japan'].std()

    df['z_growth'] = (df['rgdp_growth'] - g_mean) / g_std
    df['z_inflation'] = (df['inflation_japan'] - i_mean) / i_std

    # Two-component index (maximum coverage). The sign is flipped so that
    # lower growth and lower inflation produce a HIGHER index value.
    df['japan_index_2c'] = -(df['z_growth'] + df['z_inflation']) / 2.0

    # Three-component index: moments come from the rate-available subsample;
    # rows without a rate stay NaN through ordinary NaN propagation.
    rate_mask = df['rate_japan'].notna()
    r_mean = df.loc[rate_mask, 'rate_japan'].mean()
    r_std = df.loc[rate_mask, 'rate_japan'].std()
    df['z_rate'] = (df['rate_japan'] - r_mean) / r_std
    df['japan_index_3c'] = -(
        df['z_growth'] + df['z_inflation'] + df['z_rate']
    ) / 3.0

    # Primary index: use 3-component where available, 2-component otherwise
    df['japan_index'] = df['japan_index_3c'].fillna(df['japan_index_2c'])

    print(f"  2-component: {df['japan_index_2c'].notna().sum():,} obs")
    print(f"  3-component: {df['japan_index_3c'].notna().sum():,} obs")
    print(f"  Primary (combined): {df['japan_index'].notna().sum():,} obs")
    print(f"  Mean={df['japan_index'].mean():.3f}, SD={df['japan_index'].std():.3f}")


def _add_rolling_index(df: pd.DataFrame) -> None:
    """Add the 5-row rolling mean of the primary index to *df* in place.

    *df* must already be sorted by ``iso3`` and ``year``. Windows count rows,
    not calendar years, so a gap in a country's years widens the span covered.
    """
    print("\n--- Rolling Japanification Index (5yr MA) ---")

    df['japan_index_rolling'] = (
        df.groupby('iso3')['japan_index']
        .transform(lambda x: x.rolling(5, min_periods=3).mean())
    )

    print(f"  Rolling index: {df['japan_index_rolling'].notna().sum():,} obs")


def _add_risk_score(df: pd.DataFrame) -> None:
    """Add the forward-looking Japanification indicator to *df* in place.

    *df* must already be sorted by ``iso3`` and ``year``.
    """
    print("\n--- Japanification Risk Score ---")

    # "Japanification state" is japan_binary_robust == 1.
    # A trailing 5-row max shifted back by 4 rows gives, for row t, the max of
    # the state over rows t..t+4 (the current row is included). The last four
    # rows of each country have no complete forward window and are NaN.
    df['japan_state_5yr'] = (
        df.groupby('iso3')['japan_binary_robust']
        .transform(lambda x: x.rolling(5, min_periods=1).max().shift(-4))
    )

    # Placeholder risk score: a direct copy of the forward indicator
    # (approximation — proper model estimated in Phase 3).
    df['japan_risk_score'] = df['japan_state_5yr']

    n_risk = df['japan_risk_score'].notna().sum()
    print(f"  Risk score computed: {n_risk:,} obs")
    print(f"  Countries ever at risk: "
          f"{df.loc[df['japan_risk_score'] == 1, 'iso3'].nunique()}")


def _report_summary(df: pd.DataFrame):
    """Print the summary tables and return ``(summary, top, country_avg)``."""
    print(f"\n{'=' * 70}")
    print("INDEX SUMMARY STATISTICS")
    print("=" * 70)

    idx_vars = ['japan_index', 'japan_index_2c', 'japan_index_3c',
                'japan_index_rolling', 'japan_binary_strict', 'japan_binary_robust']
    summary = df[idx_vars].describe().T
    summary['non_missing'] = df[idx_vars].notna().sum()
    print(summary[['non_missing', 'mean', 'std', 'min', '25%', '50%', '75%', 'max']]
          .to_string(float_format='%.3f'))

    # Top-20 most Japanified country-years
    print("\n--- Top 20 Most Japanified Country-Years (3-component) ---")
    top = (df[df['japan_index_3c'].notna()]
           .nlargest(20, 'japan_index_3c')[['iso3', 'year', 'japan_index_3c',
                                             'rgdp_growth', 'inflation_japan', 'rate_japan']])
    print(top.to_string(index=False, float_format='%.3f'))

    # Country averages, highest mean index first
    print("\n--- Top 15 Countries by Mean Japanification Index ---")
    country_avg = (df.groupby('iso3')['japan_index']
                   .agg(['mean', 'std', 'count'])
                   .sort_values('mean', ascending=False))
    print(country_avg.head(15).to_string(float_format='%.3f'))

    return summary, top, country_avg


def main(processed_dir=None, table_dir=None):
    """Build the Japanification indices and write the indexed panel and tables.

    Reads ``japan_panel.csv`` from *processed_dir*, adds the binary indicators,
    the continuous (2-/3-component and primary) indices, the rolling index and
    the forward-looking risk score, prints summary statistics, then writes
    ``japan_panel_indexed.csv`` to *processed_dir* and three ``phase2_*.csv``
    tables to *table_dir*.

    Args:
        processed_dir: Directory holding the input panel and receiving the
            indexed panel. Defaults to the module-level ``PROCESSED_DIR``.
        table_dir: Directory receiving the summary tables; created if it does
            not exist. Defaults to the module-level ``TABLE_DIR``.

    Returns:
        The indexed panel as a DataFrame sorted by ``iso3`` and ``year``.
    """
    # Resolve defaults lazily so explicit arguments never touch the constants.
    processed_dir = Path(PROCESSED_DIR if processed_dir is None else processed_dir)
    table_dir = Path(TABLE_DIR if table_dir is None else table_dir)

    print("=" * 70)
    print("PHASE 2: Japanification Index Construction")
    print("=" * 70)

    df = pd.read_csv(processed_dir / "japan_panel.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # 1-2. Row-wise indicators; these do not depend on row order.
    _add_binary_indicators(df)
    _add_continuous_index(df)

    # 3-4. Window-based measures need each country's rows in year order.
    df = df.sort_values(['iso3', 'year'])
    _add_rolling_index(df)
    _add_risk_score(df)

    # 5. Summary statistics
    summary, top, country_avg = _report_summary(df)

    # 6. Save
    indexed_path = processed_dir / "japan_panel_indexed.csv"
    df.to_csv(indexed_path, index=False)
    print(f"\nSaved: {indexed_path}")
    print(f"  {len(df):,} obs, {df['iso3'].nunique()} countries")

    # The tables directory may not exist on a fresh checkout.
    table_dir.mkdir(parents=True, exist_ok=True)
    summary.to_csv(table_dir / "phase2_index_summary.csv")
    country_avg.to_csv(table_dir / "phase2_country_rankings.csv")
    top.to_csv(table_dir / "phase2_top_japanified.csv", index=False)

    return df


if __name__ == "__main__":
    # Run the pipeline when executed as a script. The result is bound to a
    # module-level name, presumably so it can be inspected in an interactive
    # session afterwards.
    df = main()
